查看原文
其他

【综述专栏】Transformer结构理解及一些细节!

在科学研究中,从方法论上来讲,都应“先见森林,再见树木”。当前,人工智能学术研究方兴未艾,技术迅猛发展,可谓万木争荣,日新月异。对于AI从业者来说,在广袤的知识森林中,系统梳理脉络,才能更好地把握趋势。为此,我们精选国内外优秀的综述文章,开辟“综述专栏”,敬请关注。

来源:知乎—李狗蛋
地址:https://zhuanlan.zhihu.com/p/260159594

01

整体结构
通俗易懂的总结下Transformer的工作机制,先上图:
Encoder的目的
将长度为N的序列的各自初始化的Embedding向量,转换成各自融合了全局语义的N个Embedding,即编码的过程。
Decoder部分
1、Masked多头注意力的目的
更容易理解的解释:传统的Seq2Seq模型使用的多时RNN,RNN的特点是输入数据是单个的,且按照时序的,无论怎样也获取不到来自未来时刻的信息。而Transformer中使用Attention替换RNN,这使得在训练阶段中,在Decoder输入数据时,输入的是一整句,这其中包含了等待被预测的后续的序列信息,因此需要将注意力矩阵中主对角线以上的元素进行Mask。实现Mask的方式很简单,仅需初始化一个下三角矩阵为0,上三角元素均为负无穷的矩阵然后加到注意力矩阵上,因为注意力需要经过Softmax进行归一化,其中e^-inf为0,因此可以将未来信息抹去。
避免在进行序列预测时使模型获取本不应该知道的信息,对从Decoder中获取的前几个时间步的信息进行。
输入:上一时间步的输出Embedding。
输出:经过Masked自注意力机制后的Embedding(已经融合了前几个时间步的预测内容),其含义表示当前需要翻译的内容。
2、多头注意力部分的目的
将Masked注意力产生的向量作为Query,和已经编码好的特征向量(Encoder的输出)作为K和V,计算注意力,从而得到当前需要翻译的内容和特征向量的对应关系,从而表示出当前时间步的状态。
Position Encoding为何直接加到Embedding上?
答:可以从另一个视角来看这件事,相当于给每一个单词索引向量后面拼接一个one-hot位置索引向量,得到[xi,pi],然后再经过Embedding矩阵W转换成每个单词相应的Embedding。根据分块矩阵,其等价于WI*xi+Wp*pi,而前一部分表示的就是单词的Word Embedding语义信息,后一部分表示位置信息。从而我们可以理解为什么直接相加。
在阅读一些代码时遇到一些问题,正好评论区的博主给出了解答,主要涉及mask的问题,即Transformer中的两种mask,padding_mask和sequence_mask。一般的self-attention需要用到padding_mask,masked self_attention除此之外还需要sequence_mask来屏蔽未来信息。
第二张图的解释可能存在一些问题,由于在encoder中padding mask没有对注意力矩阵的行进行mask,只对列进行mask,所以decoder的encode-decode-attention还要进行一次mask的不是因为之前进行了残差连接和归一化,是因为需要计算目标句子中的每个词对源句子中每个词的关注度,但是源语言句子也是经过pad填充的,所以在进行一次mask的原因是将目标句子中的每个词对源语言句子中pad的词的关注度置为负无穷。

02

多头注意力(Multi-head Attention)
多头注意力的实现方式可能有多中,最常见的:通过共享的参数矩阵Wq,Wk,Wv映射到768维度,然后根据头数(比如12),将768维reshape成12×64,然后每个头分别计算Attention,再经过一层Dense(768×768)进行融合。这样实现的参数量为(768×768+768)×4。
另一种是每个头(比如12),都有参数矩阵直接映射到64维,而不是通过reshape的方式。这样的参数量为:[(768×64+64)×12] × 4 (乘以4是因为Q,K,V,Dense)
另外feed-faward包含两层Dense,一层768×3084,第二层再从3084降至768。

03

一些细节
1、scaled dot-product?
为什么采用scaled dot-product可以参考:
https://zhuanlan.zhihu.com/p/391536998
2、add & norm?
通过源码可以看出是每层子结构的输出先做layer norm,在做dropout,然后残差连接。
具体可以表示为: 

这里的sublayer可以是attention或者feedfaword layer。
因此模型整体的流程是:每个子模块的输出与输入先做残差连接,然后做norm输入到下一层。
参考:https://blog.csdn.net/u013510838/article/details/105980363
3、FeedFarword内部结构?
FeedFarword包含两层全连接。更具体地:全连接1 -> Relu -> dropout -> 全连接2

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“综述专栏”历史文章


更多综述专栏文章,

请点击文章底部“阅读原文”查看



分享、点赞、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存